import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import rgb2gray
from skimage.metrics import normalized_root_mse as nrmse
from skimage.metrics import structural_similarity as ssim
from skimage.registration import optical_flow_tvl1, optical_flow_ilk
from skimage.transform import warp
from tqdm.notebook import tqdm
Next, let’s write a function to handle retrieving the images. Since we only want to register fundus images of the same eye to one another, we specify which patient and which laterality (left or right eye) we want to load:
Code
def retrieve_images(patient_id='156518', laterality='L'):
    """Load one patient's fundus images for one eye as grayscale arrays.

    Images are read from ``../data/<patient_id>``; a file is selected when
    its name contains ``<laterality>.png``. The first selected image (in
    sorted filename order) becomes the registration template.

    Args:
        patient_id: Name of the patient's sub-directory under ``../data``.
        laterality: Eye marker expected in the filename, e.g. ``'L'``.

    Returns:
        A tuple ``(final_images, template)``: ``template`` is the first
        grayscale image and ``final_images`` holds the remaining grayscale
        images whose shape equals the template's.

    Raises:
        FileNotFoundError: If the patient directory does not exist, or if
            no matching image could be read from it.
    """
    root_dir = Path(f'../data/{patient_id}')

    # os.listdir returns entries in arbitrary order; sort so that the same
    # image is chosen as the template on every run.
    image_filenames = sorted(
        f for f in os.listdir(root_dir) if f'{laterality}.png' in f
    )

    # cv2.imread signals an unreadable file by returning None instead of
    # raising, so drop those entries before converting.
    images = [cv2.imread(str(root_dir / f)) for f in image_filenames]
    images = [img for img in images if img is not None]
    if not images:
        raise FileNotFoundError(
            f"No readable '{laterality}.png' images found in {root_dir}"
        )

    # cv2.imread returns channels in BGR order while rgb2gray applies its
    # luminance weights assuming RGB, so reorder the channels first.
    gray_images = [
        rgb2gray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in images
    ]

    # All other images are registered to the first one.
    template = gray_images[0]

    # Images with a different shape cannot be compared pixel-wise to the
    # template, so they are discarded.
    final_images = [x for x in gray_images[1:] if x.shape == template.shape]
    return final_images, template
When evaluating our registration algorithm, our evaluation metric will be some function that computes the distance between the registered images and the template image. We want to be able to track a few of these metrics. Some common ones include:
L1 loss, also known as mean absolute error, measures the average magnitude of the element-wise differences between two images. It is robust to outliers and gives equal weight to all pixels, making it a good choice for image registration.
RMSE, or root mean square error, is the square root of the mean of the squared differences between two images. It gives more weight to larger differences, making it sensitive to outliers. RMSE is commonly used in image registration to measure the overall difference between two images.
Normalised cross-correlation is a measure of the similarity between two images, taking into account their intensities. It is normalised to ensure that the result is between -1 and 1, where 1 indicates a perfect match. Normalised cross-correlation is often used in image registration to assess the quality of the registration, especially when dealing with images with different intensities.
Similarity is a measure of the overlap between two images, taking into account both the intensities and spatial information. Common similarity metrics used in image registration include mutual information, normalised mutual information, and the Jensen-Shannon divergence. These metrics provide a measure of the information shared between two images, making them well suited for assessing the quality of image registration.
The following function takes a list of registered images, as well as the template image, and calculates the above metrics for each image:
Given these losses, it’s probably a good idea to have some sort of function that shows us the best and worst registered images, based on the loss. This is somewhat similar to viewing individual examples from a confusion matrix in a classification task.
Code
def visualise_registration_results(registered_images, original_images, template, loss_values):
    """Plot the three highest-loss and three lowest-loss registrations.

    A 3x4 grid is drawn. Columns 0-1 show the original/registered pair for
    the three largest entries of ``loss_values``; columns 2-3 show the pair
    for the three smallest entries.

    Args:
        registered_images: Registered images, indexable by position.
        original_images: Unregistered images, aligned with
            ``registered_images``.
        template: Accepted for interface compatibility; not used here.
        loss_values: One loss value per image.
    """
    # Ascending order: the tail holds the largest losses, the head the smallest.
    ranking = np.argsort(loss_values)
    highest_loss = ranking[-3:]
    lowest_loss = ranking[:3]

    fig, axes = plt.subplots(3, 4, figsize=(20, 15))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    # Each section occupies two adjacent columns: original, then registered.
    sections = ((0, highest_loss), (2, lowest_loss))
    for first_col, indices in sections:
        for row, idx in enumerate(indices):
            original_ax = axes[row][first_col]
            original_ax.imshow(original_images[idx], cmap='gray')
            original_ax.set_title("Original Image")

            registered_ax = axes[row][first_col + 1]
            registered_ax.imshow(registered_images[idx], cmap='gray')
            registered_ax.set_title(
                "Registered Image (L1 Loss: {:.2f})".format(loss_values[idx])
            )

    plt.show()
Exploratory data analysis
Code
# Directory containing the fundus images for this patient
fundus_dir = Path("../data/156518")

# Keep only left-eye PNG files. iterdir() yields entries in arbitrary order,
# so sort them to make "the first image" the same file on every run.
fundus_images = sorted(
    x
    for x in fundus_dir.iterdir()
    if x.is_file() and x.suffix == ".png" and "_L.png" in x.name
)

# Fail with a descriptive message rather than a bare IndexError below
if not fundus_images:
    raise FileNotFoundError(f"No '_L.png' images found in {fundus_dir}")

# Load the first fundus image with Matplotlib and display it
first_image = fundus_images[0]
image = plt.imread(first_image)
plt.imshow(image, cmap='gray')
plt.show()
Code
# Grid dimensions for the overview figure
nrows, ncols = 3, 3

# Create the subplots and flatten them so they can be filled in order
fig, axs = plt.subplots(nrows, ncols, figsize=(15, 15))
axs = axs.ravel()

# Show up to nrows * ncols images. zip stops at the shorter sequence, so a
# patient with fewer images than grid cells no longer raises IndexError.
for ax, image_path in zip(axs, fundus_images):
    image = plt.imread(image_path)
    ax.imshow(image, cmap='gray')

# Hide the axis decorations on every cell, including any left empty
for ax in axs:
    ax.axis("off")

plt.tight_layout()
plt.show()
Code
# Root directory for the patient data
root_dir = Path('../data/156518')

# Left-eye image filenames. os.listdir returns entries in arbitrary order,
# so sort them to keep the choice of template reproducible.
image_filenames = sorted(f for f in os.listdir(root_dir) if 'L.png' in f)

# cv2.imread returns None for files it cannot decode; drop those entries
images = [cv2.imread(str(root_dir / f)) for f in image_filenames]
images = [img for img in images if img is not None]
if not images:
    raise FileNotFoundError(f"No readable 'L.png' images found in {root_dir}")

# cv2.imread returns BGR channel order while rgb2gray expects RGB, so
# reorder the channels before converting to grayscale
gray_images = [rgb2gray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in images]

# Register all images to the first image
template = gray_images[0]

# Remove images whose shape differs from the template's
final_images = [x for x in gray_images[1:] if x.shape == template.shape]

# The pixel coordinate grid depends only on the template, so build it once
# instead of on every loop iteration
nr, nc = template.shape
row_coords, col_coords = np.meshgrid(np.arange(nr), np.arange(nc), indexing='ij')

# Do the registration process
registered_images = []
for img in tqdm(final_images):
    # Estimate the optical-flow vector field between template and image
    v, u = optical_flow_tvl1(template, img)

    # Resample the image at the flow-displaced coordinates
    registered = warp(img, np.array([row_coords + v, col_coords + u]), mode='edge')
    registered_images.append(registered)
def visualise_registration_results(registered_images, original_images, template, loss_values):
    """Plot the three highest-loss and three lowest-loss registrations.

    A 3x4 grid is drawn. Columns 0-1 show the original/registered pair for
    the three largest entries of ``loss_values``; columns 2-3 show the pair
    for the three smallest entries.

    Args:
        registered_images: Registered images, indexable by position.
        original_images: Unregistered images, aligned with
            ``registered_images``.
        template: Accepted for interface compatibility; not used here.
        loss_values: One loss value per image.
    """
    # Ascending order: the tail holds the largest losses, the head the smallest.
    ranking = np.argsort(loss_values)
    highest_loss = ranking[-3:]
    lowest_loss = ranking[:3]

    fig, axes = plt.subplots(3, 4, figsize=(20, 15))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    # Each section occupies two adjacent columns: original, then registered.
    sections = ((0, highest_loss), (2, lowest_loss))
    for first_col, indices in sections:
        for row, idx in enumerate(indices):
            original_ax = axes[row][first_col]
            original_ax.imshow(original_images[idx], cmap='gray')
            original_ax.set_title("Original Image")

            registered_ax = axes[row][first_col + 1]
            registered_ax.imshow(registered_images[idx], cmap='gray')
            registered_ax.set_title(
                "Registered Image (L1 Loss: {:.2f})".format(loss_values[idx])
            )

    plt.show()
def retrieve_images(patient_id='156518', laterality='L'):
    """Load one patient's fundus images for one eye as grayscale arrays.

    Images are read from ``../data/<patient_id>``; a file is selected when
    its name contains ``<laterality>.png``. The first selected image (in
    sorted filename order) becomes the registration template.

    Args:
        patient_id: Name of the patient's sub-directory under ``../data``.
        laterality: Eye marker expected in the filename. Defaults to
            ``'L'``, which matches the previously hard-coded ``'L.png'``
            filter.

    Returns:
        A tuple ``(final_images, template)``: ``template`` is the first
        grayscale image and ``final_images`` holds the remaining grayscale
        images whose shape equals the template's.

    Raises:
        FileNotFoundError: If the patient directory does not exist, or if
            no matching image could be read from it.
    """
    root_dir = Path(f'../data/{patient_id}')

    # os.listdir returns entries in arbitrary order; sort so that the same
    # image is chosen as the template on every run.
    image_filenames = sorted(
        f for f in os.listdir(root_dir) if f'{laterality}.png' in f
    )

    # cv2.imread signals an unreadable file by returning None instead of
    # raising, so drop those entries before converting.
    images = [cv2.imread(str(root_dir / f)) for f in image_filenames]
    images = [img for img in images if img is not None]
    if not images:
        raise FileNotFoundError(
            f"No readable '{laterality}.png' images found in {root_dir}"
        )

    # cv2.imread returns channels in BGR order while rgb2gray applies its
    # luminance weights assuming RGB, so reorder the channels first.
    gray_images = [
        rgb2gray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in images
    ]

    # All other images are registered to the first one.
    template = gray_images[0]

    # Images with a different shape cannot be compared pixel-wise to the
    # template, so they are discarded.
    final_images = [x for x in gray_images[1:] if x.shape == template.shape]
    return final_images, template
Code
def optical_flow(template, img):
    """Register ``img`` to ``template`` using TV-L1 optical flow.

    The flow field returned by ``optical_flow_tvl1(template, img)`` is added
    to the template's pixel coordinate grid, and ``img`` is resampled at the
    resulting coordinates with ``warp`` (edge padding outside the image).

    Args:
        template: 2-D reference image (its shape is unpacked into two sizes).
        img: Image to be aligned with ``template``.

    Returns:
        The warped version of ``img``.
    """
    # Two flow components, used below as row and column offsets respectively
    flow_rows, flow_cols = optical_flow_tvl1(template, img)

    # Coordinate grid covering every template pixel
    height, width = template.shape
    grid_rows, grid_cols = np.meshgrid(
        np.arange(height), np.arange(width), indexing='ij'
    )

    # Sampling positions: each pixel's coordinates shifted by its flow vector
    sample_coords = np.array([grid_rows + flow_rows, grid_cols + flow_cols])
    return warp(img, sample_coords, mode='edge')
The main idea is to solve a pairwise optimisation problem by minimising the cost function $C$. The optimisation can be formulated as $$\hat{T} = \text{argmin}_T \, C(T, I_f, I_m)$$ with the cost function defined as $$C(T, I_f, I_m) = -S(T, I_f, I_m) + \gamma P(T)$$ where $T$ is the transformation matrix, $I_f$ and $I_m$ are the fixed and moving images, $S$ is the similarity measurement and $P$ is the penalty term, weighted by the regularisation parameter $\gamma$.
SimpleElastix is based on the parametric approach to solving the optimisation problem, in which the number of possible transformations is limited by introducing a parametrisation (model) of the transform. The optimisation becomes $$\hat{T}_\mu = \text{argmin}_{T_\mu} \, C(T_\mu, I_f, I_m)$$ where $T_\mu$ denotes the parametrisation model and the vector $\mu$ contains the values of the transformation parameters. For a 2D rigid transformation, the parameter vector $\mu$ contains one rotation angle and the translations in the x and y directions.
Code
import SimpleITK as sitk


def simple_elastix(image, template, iterations=100, smoothing_sigma=0.01):
    """Register ``image`` to ``template`` with SimpleITK's Demons filter.

    Despite its name, this function uses ``sitk.DemonsRegistrationFilter``.
    The filter's ``Execute`` call yields a displacement field, not a
    registered image, so the field is wrapped in a
    ``DisplacementFieldTransform`` and used to resample the moving image
    onto the template's grid.

    Args:
        image: Moving image as a NumPy array.
        template: Fixed (reference) image as a NumPy array.
        iterations: Number of Demons iterations.
        smoothing_sigma: Standard deviation passed to the filter's
            ``SetStandardDeviations``.

    Returns:
        The moving image resampled onto the template grid, as a NumPy array.
    """
    # Convert the input arrays to SimpleITK images
    moving_image = sitk.GetImageFromArray(image)
    fixed_image = sitk.GetImageFromArray(template)

    # Configure the Demons registration
    demons = sitk.DemonsRegistrationFilter()
    demons.SetNumberOfIterations(iterations)
    demons.SetStandardDeviations(smoothing_sigma)

    # Execute returns the estimated displacement field (a vector image)
    displacement_field = demons.Execute(fixed_image, moving_image)

    # DisplacementFieldTransform expects a 64-bit float vector image
    transform = sitk.DisplacementFieldTransform(
        sitk.Cast(displacement_field, sitk.sitkVectorFloat64)
    )

    # Apply the deformation: resample the moving image on the fixed grid
    registered_image = sitk.Resample(
        moving_image,
        fixed_image,
        transform,
        sitk.sitkLinear,
        0.0,
        moving_image.GetPixelID(),
    )

    return sitk.GetArrayFromImage(registered_image)
And now we preprocess the images as before, getting the template and the images to be registered:
Code
# Load the images to be registered and the template they are registered to
images, template = retrieve_images()

# Evaluate the SimpleElastix-based registration
opt = RegistrationAlgorithm(simple_elastix)
l1_losses, ncc_values, ssim_values = opt.evaluate_registration()

# Report the mean of each metric, rounded to two decimal places
for label, values in (
    ("L1 losses:", l1_losses),
    ("Normalized cross-correlation values:", ncc_values),
    ("Structural similarity index values:", ssim_values),
):
    print(label, f"{np.mean(values):.2f}")